import torch
import numpy as np

def prep_grad(x):
    """
    Flatten ``x`` to one dimension and report its original shape and length.

    :param x: tensor to flatten (a 0-d tensor becomes a 1-element vector,
        since it is unsqueezed before flattening)
    :return: tuple ``(x_flat, dim, d)`` -- the 1-D flattening of ``x``,
        the original ``x.shape``, and the number of entries in ``x_flat``
    """
    original_shape = x.shape
    flat = x.unsqueeze(0).flatten()
    return flat, original_shape, len(flat)

@torch.no_grad()
def clip_noise(x, h, noise):
    """
    Clip ``x`` to L2 norm at most ``h``, then add Gaussian noise to it.

    If ``torch.norm(x)`` exceeds ``h``, ``x`` is first rescaled by
    ``h / norm``.  Afterwards ``noise * torch.randn_like(x)`` is added
    unconditionally, so the returned tensor's norm may exceed ``h``.
    ``x`` is updated by reassigning ``x.data`` and is also returned.

    :param x: tensor to clip and perturb; modified and returned
    :param h: non-negative norm bound
    :param noise: scale applied to the standard-normal noise
    :return: ``x`` itself, clipped and with noise added
    :raises ValueError: if ``h`` is negative
    """
    if h < 0:
        # A negative bound would flip x's direction (and turn a zero vector
        # into NaN via 0 * (h / 0)); reject it instead of returning garbage.
        raise ValueError(f"clip bound h must be non-negative, got {h}")
    nrm = torch.norm(x)
    if nrm > h:
        x.data = torch.mul(x.data, h / nrm)
    # Noise is drawn after clipping, matching x's (unchanged) shape/dtype/device.
    x.data = torch.add(x.data, torch.randn_like(x), alpha=noise)
    return x
    
def clip_noise_wrap(h=0.1, noise=0.1):
    """
    Build a one-argument callable that applies ``clip_noise`` with fixed settings.

    Not decorated with ``torch.no_grad``: this factory does no tensor work,
    and the returned closure gets its no-grad behavior from ``clip_noise``.

    :param h: norm bound forwarded to ``clip_noise``
    :param noise: noise scale forwarded to ``clip_noise``
    :return: function ``cl(x)`` returning ``clip_noise(x, h=h, noise=noise)``
    """
    def cl(x):
        return clip_noise(x, h=h, noise=noise)
    return cl
    
@torch.no_grad()
def clip(x, h):
    """
    Rescale ``x`` so that its L2 norm is at most ``h``.

    If ``torch.norm(x)`` exceeds ``h``, ``x`` is scaled by ``h / norm``;
    otherwise it is left unchanged.  ``x`` is updated by reassigning
    ``x.data`` and is also returned.

    :param x: tensor to clip; modified and returned
    :param h: non-negative norm bound
    :return: ``x`` itself, after clipping
    :raises ValueError: if ``h`` is negative
    """
    if h < 0:
        # A negative bound would flip x's direction (and turn a zero vector
        # into NaN via 0 * (h / 0)); reject it instead of returning garbage.
        raise ValueError(f"clip bound h must be non-negative, got {h}")
    nrm = torch.norm(x)
    if nrm > h:
        x.data = torch.mul(x.data, h / nrm)
    return x
    
def clip_wrap(h=0.1):
    """
    Build a one-argument callable that applies ``clip`` with a fixed bound.

    Not decorated with ``torch.no_grad``: this factory does no tensor work,
    and the returned closure gets its no-grad behavior from ``clip``.

    :param h: norm bound forwarded to ``clip``
    :return: function ``cl(x)`` returning ``clip(x, h=h)``
    """
    def cl(x):
        return clip(x, h=h)
    return cl